{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义层\n",
    "\n",
    "深度学习成功背后的一个因素是，可以用创造性的方式组合广泛的层，从而设计出适用于各种任务的结构。例如，研究人员发明了专门用于处理图像、文本、序列数据和执行动态编程的层。早晚有一天，你会遇到或要自己发明一个在深度学习框架中还不存在的层。在这些情况下，你必须构建自定义层。在本节中，我们将向你展示如何操作。\n",
    "\n",
    "## 不带参数的层\n",
    "\n",
    "首先，我们构造一个没有任何参数的自定义层。如果你还记得我们在 :numref:`sec_model_construction` 对块的介绍，这应该看起来很眼熟。下面的 `CenteredLayer` 类要从其输入中减去均值。要构建它，我们只需继承基础层类并实现正向传播功能。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CenteredLayer extends AbstractBlock {\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDList current = inputs;\n",
    "        // Subtract the mean from the input\n",
    "        return new NDList(current.head().sub(current.head().mean()));\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        // Output shape should be the same as input\n",
    "        return inputs;\n",
    "    } \n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们通过向其提供一些数据来验证该层是否按预期工作。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "CenteredLayer layer = new CenteredLayer();\n",
    "\n",
    "Model model = Model.newInstance(\"centered-layer\");\n",
    "model.setBlock(layer);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "NDArray input = manager.create(new float[]{1f, 2f, 3f, 4f, 5f});\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们可以将层作为组件合并到构建更复杂的模型中。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(Linear.builder().setUnits(128).build());\n",
    "net.add(new CenteredLayer());\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "net.initialize(manager, DataType.FLOAT32, input.getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作为额外的健全性检查，我们可以向网络发送随机数据后，检查均值是否为0。由于我们处理的是浮点数，因为存储精度的原因，我们仍然可能会看到一个非常小的非零数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(-0.07f, 0.07f, new Shape(4, 8));\n",
    "NDArray y = predictor.predict(new NDList(input)).singletonOrThrow();\n",
    "y.mean();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 带参数的图层\n",
    "\n",
    "既然我们知道了如何定义简单的层，那么让我们继续定义具有参数的层，这些参数可以通过训练进行调整。我们可以使用内置函数来创建参数，这些参数提供一些基本的管理功能。比如管理访问、初始化、共享、保存和加载模型参数。这样做的好处之一是，我们不需要为每个自定义层编写自定义序列化程序。\n",
    "\n",
    "现在，让我们实现自定义版本的全连接层。回想一下，该层需要两个参数，一个用于表示权重，另一个用于表示偏置项。在此实现中，我们使用 `Activation.relu` 作为激活函数。该层需要输入参数：`inUnits` 和 `outUnits`，分别表示输入和输出的数量。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLinear extends AbstractBlock {\n",
    "\n",
    "    private Parameter weight;\n",
    "    private Parameter bias;\n",
    "    \n",
    "    private int inUnits;\n",
    "    private int outUnits;\n",
    "\n",
    "    // outUnits: the number of outputs in this layer \n",
    "    // inUnits: the number of inputs in this layer\n",
    "    public MyLinear(int outUnits, int inUnits) {\n",
    "        this.inUnits = inUnits;\n",
    "        this.outUnits = outUnits;\n",
    "        weight = addParameter(\n",
    "            Parameter.builder()\n",
    "                .setName(\"weight\")\n",
    "                .setType(Parameter.Type.WEIGHT)\n",
    "                .optShape(new Shape(inUnits, outUnits))\n",
    "                .build());\n",
    "        bias = addParameter(\n",
    "            Parameter.builder()\n",
    "                .setName(\"bias\")\n",
    "                .setType(Parameter.Type.BIAS)\n",
    "                .optShape(new Shape(outUnits))\n",
    "                .build());\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        NDArray input = inputs.singletonOrThrow();\n",
    "        Device device = input.getDevice();\n",
    "        // Since we added the parameter, we can now access it from the parameter store\n",
    "        NDArray weightArr = parameterStore.getValue(weight, device, false);\n",
    "        NDArray biasArr = parameterStore.getValue(bias, device, false);\n",
    "        return relu(linear(input, weightArr, biasArr));\n",
    "    }\n",
    "    \n",
    "    // Applies linear transformation\n",
    "    public static NDArray linear(NDArray input, NDArray weight, NDArray bias) {\n",
    "        return input.dot(weight).add(bias);\n",
    "    }\n",
    "    \n",
    "    // Applies relu transformation\n",
    "    public static NDList relu(NDArray input) {\n",
    "        return new NDList(Activation.relu(input));\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        return new Shape[]{new Shape(outUnits, inUnits)};\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，我们实例化`MyLinear`类并访问其模型参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// 5 units in -> 3 units out\n",
    "MyLinear linear = new MyLinear(3, 5); \n",
    "var params = linear.getParameters();\n",
    "for (Pair<String, Parameter> param : params) {\n",
    "    System.out.println(param.getKey());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以使用自定义层直接执行正向传播计算。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(0, 1, new Shape(2, 5));\n",
    "\n",
    "linear.initialize(manager, DataType.FLOAT32, input.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"my-linear\");\n",
    "model.setBlock(linear);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们还可以使用自定义层构建模型。我们可以像使用内置的全连接层一样使用自定义层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray input = manager.randomUniform(0, 1, new Shape(2, 64));\n",
    "\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(new MyLinear(8, 64)); // 64 units in -> 8 units out\n",
    "net.add(new MyLinear(1, 8)); // 8 units in -> 1 unit out\n",
    "net.initialize(manager, DataType.FLOAT32, input.getShape());\n",
    "\n",
    "Model model = Model.newInstance(\"lin-reg-custom\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Predictor<NDList, NDList> predictor = model.newPredictor(new NoopTranslator());\n",
    "predictor.predict(new NDList(input)).singletonOrThrow();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层，其行为与库中的任何现有层不同。\n",
    "* 在自定义层定义完成后，就可以在任意环境和网络结构中调用该自定义层。\n",
    "* 层可以有局部参数，这些参数可以通过内置函数创建。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 设计一个接受输入并计算 `NDArray` 汇总的层，它返回$y_k = \\sum_{i, j} W_{ijk} x_i x_j$。\n",
    "1. 设计一个返回输入数据的傅立叶系数前半部分的层。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
